// Copyright 1998-2016 Glenn McIntosh
// licensed under the GNU General Public Licence version 3

// include files
#include "transform.h"
#include <functional>
#include <cmath>
#include <cassert>
#include <array>
#include <utility>

namespace math
{
namespace
{
// cosine/sine tables
// Cosine/sine pair of one angle. "Adding" two pairs applies the angle-sum
// identities, "subtracting" applies the angle-difference identities, and
// reflect() exchanges the two components.
template<typename R> struct CS
{
	R c, s;

	// in-place angle sum: (c, s) <- (c*x.c - s*x.s, s*x.c + c*x.s)
	CS &operator+=(const CS &x)
	{
		// both results are computed from the old members before either is stored
		const R cosine = c*x.c - s*x.s;
		const R sine = s*x.c + c*x.s;
		c = cosine;
		s = sine;
		return *this;
	}

	// angle sum returned as a new pair
	CS operator+(const CS &x) const
	{
		CS sum = *this;
		sum += x;
		return sum;
	}

	// in-place angle difference: (c, s) <- (c*x.c + s*x.s, s*x.c - c*x.s)
	CS &operator-=(const CS &x)
	{
		const R cosine = c*x.c + s*x.s;
		const R sine = s*x.c - c*x.s;
		c = cosine;
		s = sine;
		return *this;
	}

	// angle difference returned as a new pair
	CS operator-(const CS &x) const
	{
		CS difference = *this;
		difference -= x;
		return difference;
	}

	// pair with cosine and sine exchanged
	CS reflect() const
	{
		return CS{s, c};
	}
};
// pi, obtained as the arc cosine of -1
constexpr real pi = ::acos(-1);
// cosine/sine pair of the angle pi/2^i
// The divisor is built with an unsigned shift: a signed 1<<31 does not fit in
// int and would turn the last table entry into a negative angle.
constexpr CS<real> cs(int i) {return {(real) cos(pi/(1u<<i)), (real) sin(pi/(1u<<i))};}
// table of cs(0) .. cs(31), indexed by i
const CS<real> cs1[] = {cs(0),cs(1),cs(2),cs(3),cs(4),cs(5),cs(6),cs(7),cs(8),cs(9),cs(10),cs(11),cs(12),cs(13),cs(14),cs(15),cs(16),cs(17),cs(18),cs(19),cs(20),cs(21),cs(22),cs(23),cs(24),cs(25),cs(26),cs(27),cs(28),cs(29),cs(30),cs(31)};

// bit function: number of trailing zero bits of m, i.e. the index of its
// lowest set bit (thin wrapper over the compiler builtin; m must be non-zero)
constexpr int ctz(int m)
{
	return __builtin_ctz(static_cast<unsigned>(m));
}

// decimate in frequency or in time
// DFREQ selects the butterfly form and the pass in hartley() that runs before
// the bit-reversal permutation; DTIME selects the butterfly form and the pass
// that runs after it. Define exactly one: with both defined the butterfly
// lambda in radix2 declares its temporaries twice, with neither it writes
// nothing back.
//#define DFREQ
#define DTIME

// single layer Hartley computation
// Combines the two adjacent halves d[0..m) and d[m..2m) in place.
// d:   iterator to the first element of the 2m-element block
// m:   half-length of the block; hartley() only passes powers of two
// pCs: scratch storage for saved cosine/sine pairs; *pCs is overwritten with
//      {1, 0} before it is read (only reached when m is a multiple of 8)
inline void radix2(Vector::iterator d, int m, std::array<CS<real>, 32>::iterator pCs)
{
	// butterfly functions
	// cross0: plain sum/difference of d[w] and d[w+m], no rotation
	auto cross0 = [d, m](int w)
	{
		real t0 = d[w], t1 = d[w+m];
		d[w] = t0+t1; d[w+m] = t0-t1;
	};
	// cross1: processes the four elements at w, m-w, w+m and 2m-w together,
	// using the cosine/sine pair csw
	auto cross1 = [d, m](int w, const CS<real> csw)
	{
		real t0 = d[w], t1 = d[m-w], t2 = d[w+m], t3 = d[m-w+m];
#		ifdef DFREQ
		// decimation in frequency: sum/difference first, then rotate the differences
		real t4 = t0-t2, t5 = t1-t3;
		d[w] = t0+t2; d[m-w] = t1+t3; d[w+m] = csw.c*t4+csw.s*t5; d[m-w+m] = csw.s*t4-csw.c*t5;
#		endif
#		ifdef DTIME
		// decimation in time: rotate the upper-half pair first, then sum/difference
		real t4 = csw.c*t2+csw.s*t3, t5 = csw.s*t2-csw.c*t3;
		d[w] = t0+t4; d[m-w] = t1+t5; d[w+m] = t0-t4; d[m-w+m] = t1-t5;
#		endif
	};

	// transform DC, Nyquist
	cross0(0);
	// m odd: for the power-of-two sizes used here that is m == 1, nothing else to do
	if (m&1) return;

	// transform half Nyquist
	cross0(m/2);
	// bit 1 set: m == 2 for power-of-two sizes
	if (m&2) return;

	// transform quarter Nyquist frequency components
	// cs1[2] is the pair for pi/4, i.e. pi*w/m at w == m/4
	cross1(m/4, cs1[2]);
	// bit 2 set: m == 4 for power-of-two sizes
	if (m&4) return;

	// transform other frequency components
	// Each pass through the loop body handles one odd w and the following even
	// w, each together with its mirror index m/2-w (which uses the reflected pair).
	auto csw = *pCs = {1, 0};
	for (int w = 1; true; ++w)
	{
		// odd w: advance the current pair by pi/m (cs1[ctz(m)] when m is a power of two)
		csw += cs1[ctz(m)];
		cross1(w, csw);
		cross1(m/2-w, csw.reflect());
	// parses as (++w) & (m/4): leave once the now-even w has the m/4 bit set
	if (++w&m/4) break;
		// even w: rebuild the pair from the saved entry at *pCs plus the pair
		// for pi*(lowest set bit of w)/m, instead of continuing to accumulate
		csw = *pCs + cs1[ctz(m)-ctz(w)];
		// parses as (w & -w) & (w>>1): true when the bit directly above the
		// lowest set bit of w is also set; drops back one saved entry
		// NOTE(review): the saved entries appear to act as a stack of pairs for
		// w with its low set bits cleared — confirm against the original derivation
		if (w&-w&w>>1) --pCs;
		// w is a multiple of 4: save the current pair as a new entry
		if (!(w&2)) *++pCs = csw;
		cross1(w, csw);
		cross1(m/2-w, csw.reflect());
	}
}
}

// fast Hartley transform
// Transforms d in place. The length must be a power of two (asserted); a
// single-element vector is returned untouched. Every element is scaled by
// 1/sqrt(n); other routines in this file call the function a second time as
// the inverse transform.
void hartley(Vector &d)
{
	const int n = d.size();
	assert((n & (n-1)) == 0);
	if (n == 1) return;

	// decimation-in-frequency passes (compiled only when DFREQ is defined)
#	ifdef DFREQ
	for (int i = 0; i < n; i += 2)
	{
		CS<real> csStack[sizeof(int)*8];
		const auto dd = d.begin() + i;
		int in = i|n;
		if (in & 2) goto continue2;
		if (in & 4) goto continue4;
		if (in & 8) goto continue8;
		for (int m = (in&-in)>>1; m > 8; m >>= 1)
			radix2(dd, m, csStack);
		radix2(dd, 8, csStack); // unrolled
		continue8: radix2(dd, 4, csStack); // unrolled
		continue4: radix2(dd, 2, csStack); // unrolled
		continue2: radix2(dd, 1, csStack); // unrolled
	}
#	endif

	// bit-reversal permutation: for each mirrored bit pair (low, high), visit
	// every index with both bits clear and exchange its two single-bit variants
	for (int low = 1, high = n>>1; low < high; low <<= 1, high >>= 1)
	{
		const int both = low|high;
		for (int base = 0; base < n; base = (base+1+both) & ~both)
			std::swap(d[base+low], d[base+high]);
	}

	// scaling
	const real norm = 1/::sqrt(n);
	for (auto &sample: d)
		sample *= norm;

	// decimation-in-time passes (compiled only when DTIME is defined):
	// walk the even offsets from the end towards the start and, at each one,
	// run the layers m = 1, 2, 4, ... for as long as bit m of offset|n is clear
#	ifdef DTIME
	for (int offset = n-2; offset >= 0; offset -= 2)
	{
		std::array<CS<real>, sizeof(int)*8> rotations;
		const auto scratch = rotations.begin();
		const auto block = d.begin() + offset;
		const int stop = offset|n;

		// the first four layers are written out with constant sizes
		radix2(block, 1, scratch);
		if (stop & 2) continue;
		radix2(block, 2, scratch);
		if (stop & 4) continue;
		radix2(block, 4, scratch);
		if (stop & 8) continue;
		radix2(block, 8, scratch);

		// remaining layers
		int m = 16;
		while (!(stop & m))
		{
			radix2(block, m, scratch);
			m <<= 1;
		}
	}
#	endif
}

// 2 dimensional fast Hartley transform
// Transforms the square, power-of-two sized grid d in place: a 1-D transform
// of every d[k], a transpose, a second 1-D pass, a transpose back, and finally
// a correction that couples the four mirrored points (i,j), (i,n-j), (n-i,j)
// and (n-i,n-j).
void hartley(Vector2D &d)
{
	const int n = d.size();
	assert((n & (n-1)) == 0);
	assert(n == d[0].size());

	// 1-D transform of every d[k]
	const auto transformLines = [&d, n]()
	{
		for (int k = 0; k < n; ++k)
			hartley(d[k]);
	};
	// in-place transpose: exchange each element above the diagonal with its mirror
	const auto transpose = [&d, n]()
	{
		for (int r = 0; r < n; ++r)
			for (int c = r+1; c < n; ++c)
				std::swap(d[r][c], d[c][r]);
	};

	// transform along the second index
	transformLines();

	// transform along the first index, via transpose / transform / transpose
	transpose();
	transformLines();
	transpose();

	// correct casine separability
	const int half = n/2;
	for (int u = 1; u < half; ++u)
		for (int v = 1; v < half; ++v)
		{
			// u and v lie strictly between 0 and n/2, so these are four distinct elements
			real &pp = d[u][v];
			real &pm = d[u][n-v];
			real &mp = d[n-u][v];
			real &mm = d[n-u][n-v];
			const real cas = ((pp+mm) - (pm+mp))/2;
			pp -= cas;
			pm += cas;
			mp += cas;
			mm -= cas;
		}
}

// auto-correlation
// Replaces d with its circular auto-correlation: transform, combine each
// mirrored pair (i, n-i) in the frequency domain, transform back.
// d: data of power-of-two length (asserted); an empty vector is left untouched.
void autoCorrelate(Vector &d)
{
	const int n = d.size();
	assert((n & (n-1)) == 0);
	// nothing to do for empty input; d[0] below would be out of range
	if (n == 0) return;
	// compensates for the 1/sqrt(n) applied by each hartley() call
	const real scale = ::sqrt(n);

	// transform
	hartley(d);

	// multiply or divide in frequency domain
	d[0] *= d[0] * scale;
	for (int i = 1; i < n/2; ++i)
		d[i] = d[n-i] = (d[i]*d[i] + d[n-i]*d[n-i])/2 * scale;
	// for n == 1 index n/2 is index 0 again, which has already been squared
	if (n > 1)
		d[n/2] *= d[n/2] * scale;

	// inverse transform
	hartley(d);
}

// 2D convolution kernel
namespace
{
// Frequency-domain combination shared by the 2-D convolve/correlate routines.
// d: n x n transformed data, overwritten with the combined values
// f: n x n transformed filter, read only
// n: side length of both grids; callers in this file pass d.size()
// NOTE(review): unlike the 1-D convolve, no sqrt(n)-style factor is applied
// here — confirm the output scaling callers expect.
// NOTE(review): the four-point combination in the inner loop has the form of
// the 1-D rule applied independently along each index; hartley(Vector2D&)
// applies a separability correction after its two passes — confirm the two
// agree for filters that are not symmetric along each axis.
static void kernel(Vector2D &d, const Vector2D &f, const int n)
{
	// nothing to combine for an empty grid
	if (n <= 0) return;

	// points unaliased in both domains
	d[0][0] = d[0][0]*f[0][0];
	// a 1 x 1 grid has only this term: n/2 is index 0 again, so the three
	// products below would multiply the same element repeatedly
	if (n == 1) return;
	d[0][n/2] = d[0][n/2]*f[0][n/2];
	d[n/2][0] = d[n/2][0]*f[n/2][0];
	d[n/2][n/2] = d[n/2][n/2]*f[n/2][n/2];

	// remainder
	for (int i = 1; i < n/2; ++i)
	{
		real c0, c1, c2, c3;

		// points unaliased in x domain (rows 0 and n/2, mirrored pair i / n-i)
		c0 = d[0][i]*f[0][n-i] + d[0][n-i]*f[0][i];
		c2 = d[0][i]*f[0][i] - d[0][n-i]*f[0][n-i];
		d[0][i] = (c0+c2)/2;
		d[0][n-i] = (c0-c2)/2;
		c0 = d[n/2][i]*f[n/2][n-i] + d[n/2][n-i]*f[n/2][i];
		c2 = d[n/2][i]*f[n/2][i] - d[n/2][n-i]*f[n/2][n-i];
		d[n/2][i] = (c0+c2)/2;
		d[n/2][n-i] = (c0-c2)/2;

		// points unaliased in y domain (columns 0 and n/2, mirrored pair i / n-i)
		c0 = d[i][0]*f[n-i][0] + d[n-i][0]*f[i][0];
		c3 = d[i][0]*f[i][0] - d[n-i][0]*f[n-i][0];
		d[i][0] = (c0+c3)/2;
		d[n-i][0] = (c0-c3)/2;
		c0 = d[i][n/2]*f[n-i][n/2] + d[n-i][n/2]*f[i][n/2];
		c3 = d[i][n/2]*f[i][n/2] - d[n-i][n/2]*f[n-i][n/2];
		d[i][n/2] = (c0+c3)/2;
		d[n-i][n/2] = (c0-c3)/2;

		// points aliased in both domains: all four partial sums are formed
		// from the old values before any of the four elements is overwritten
		for (int j = 1; j < n/2; ++j)
		{
			c0 = d[i][j]*f[n-i][n-j] + d[n-i][n-j]*f[i][j] + d[n-i][j]*f[i][n-j] + d[i][n-j]*f[n-i][j];
			c1 = d[i][j]*f[i][j] + d[n-i][n-j]*f[n-i][n-j] - d[n-i][j]*f[n-i][j] - d[i][n-j]*f[i][n-j];
			c2 = d[i][j]*f[n-i][j] - d[n-i][n-j]*f[i][n-j] + d[n-i][j]*f[i][j] - d[i][n-j]*f[n-i][n-j];
			c3 = d[i][j]*f[i][n-j] - d[n-i][n-j]*f[n-i][j] - d[n-i][j]*f[n-i][n-j] + d[i][n-j]*f[i][j];
			d[i][j] = ((c0+c1)+(c2+c3))/4;
			d[n-i][n-j] = ((c0+c1)-(c2+c3))/4;
			d[i][n-j] = ((c0-c1)-(c2-c3))/4;
			d[n-i][j] = ((c0-c1)+(c2-c3))/4;
		}
	}
}
}

// cross correlation
// Correlates the square grid d with the reference r, in place.
// d: n x n data, n a power of two (asserted); overwritten with the result
// r: m x m reference with m odd and m <= n (asserted)
void correlate(Vector2D &d, const Vector2D &r)
{
	const int n = d.size(), m = r.size();
	assert((n & (n-1)) == 0);
	// nothing to do for an empty grid; d[0] below would be out of range
	if (n == 0) return;
	assert(n == static_cast<int>(d[0].size()));
	assert((m&1) == 1);
	// comparison, not assignment: r must be square
	assert(m == static_cast<int>(r[0].size()));
	// the wrap-around placement below only stays inside f when r fits into n
	assert(m <= n);

	// align vector sizes: the first m/2+1 rows/columns of r go to the start of
	// the zero-filled n x n grid, the remaining ones wrap to its end
	Vector line(n, 0.0);
	Vector2D f(n, line);
	for (int i = 0; i < m; ++i)
		for (int j = 0; j < m; ++j)
			f[i <= m/2 ? i : i-m+n][j <= m/2 ? j : j-m+n] = r[i][j];

	// transform
	hartley(d);
	hartley(f);

	// rotate 180 degrees: exchange every point (i,j) with (n-i,n-j), indices
	// taken modulo n; (0,0), (0,n/2), (n/2,0) and (n/2,n/2) map onto themselves
	for (int i = 1; i < n/2; ++i)
	{
		std::swap(f[i][0], f[n-i][0]);
		std::swap(f[i][n/2], f[n-i][n/2]);
		// rows 0 and n/2 mirror along the second index
		std::swap(f[0][i], f[0][n-i]);
		std::swap(f[n/2][i], f[n/2][n-i]);
		for (int j = 1; j < n/2; ++j)
		{
			std::swap(f[i][j], f[n-i][n-j]);
			std::swap(f[i][n-j], f[n-i][j]);
		}
	}

	// multiply in frequency domain
	kernel(d, f, n);

	// inverse transform
	hartley(d);
}

// convolution
// Circularly convolves d with r, in place.
// d: data of power-of-two length n (asserted); overwritten with the result
// r: kernel of odd length m with m <= n (asserted)
void convolve(Vector &d, const Vector &r)
{
	const int n = d.size(), m = r.size();
	assert((n & (n-1)) == 0);
	// nothing to do for empty input; d[0] below would be out of range
	if (n == 0) return;
	assert((m&1) == 1);
	// the wrap-around placement below only stays inside f when r fits into n
	assert(m <= n);
	// compensates for the 1/sqrt(n) applied by each hartley() call
	const real scale = ::sqrt(n);

	// align vector sizes: the first m/2+1 entries of r go to the start of the
	// zero-filled buffer, the remaining ones wrap to its end
	Vector f(n, 0.0);
	for (int i = 0; i < m; ++i)
		f[i <= m/2 ? i : i-m+n] = r[i];

	// transform
	hartley(d);
	hartley(f);

	// multiply in frequency domain
	d[0] = d[0]*f[0] * scale;
	// for n == 1 index n/2 is index 0 again, which has already been multiplied
	if (n > 1)
		d[n/2] = d[n/2]*f[n/2] * scale;
	for (int i = 1; i < n/2; ++i)
	{
		// both partial sums use the old d[i] and d[n-i]
		real c0 = d[i]*f[i] - d[n-i]*f[n-i];
		real c1 = d[i]*f[n-i] + d[n-i]*f[i];
		d[i] = (c1+c0)/2 * scale;
		d[n-i] = (c1-c0)/2 * scale;
	}

	// inverse transform
	hartley(d);
}

// convolution
// Convolves the square grid d with the kernel r, in place.
// d: n x n data, n a power of two (asserted); overwritten with the result
// r: m x m kernel with m odd and m <= n (asserted)
void convolve(Vector2D &d, const Vector2D &r)
{
	const int n = d.size(), m = r.size();
	assert((n & (n-1)) == 0);
	// nothing to do for an empty grid; d[0] below would be out of range
	if (n == 0) return;
	assert(n == static_cast<int>(d[0].size()));
	assert((m&1) == 1);
	// comparison, not assignment: r must be square
	assert(m == static_cast<int>(r[0].size()));
	// the wrap-around placement below only stays inside f when r fits into n
	assert(m <= n);

	// align vector sizes: the first m/2+1 rows/columns of r go to the start of
	// the zero-filled n x n grid, the remaining ones wrap to its end
	Vector line(n, 0.0);
	Vector2D f(n, line);
	for (int i = 0; i < m; ++i)
		for (int j = 0; j < m; ++j)
			f[i <= m/2 ? i : n-m+i][j <= m/2 ? j : n-m+j] = r[i][j];

	// transform
	hartley(d);
	hartley(f);

	// multiply in frequency domain
	kernel(d, f, n);

	// inverse transform
	hartley(d);
}

// Haar wavelet forward transform
// In place, for power-of-two lengths (asserted). At each scale the elements
// `half` apart are replaced by their sum (kept at the lower index) and their
// difference (kept at the upper index); no normalisation is applied.
void haarTransform0(Vector &d)
{
	const int n = d.size();
	assert((n & (n-1)) == 0);

	// scales from finest to coarsest
	for (int half = 1; half < n; half <<= 1)
	{
		const int step = half<<1;
		for (int base = 0; base < n; base += step)
		{
			// lifting steps
			real &lower = d[base];
			real &upper = d[base+half];
			const real sum = lower + upper;
			const real difference = lower - upper;
			lower = sum;
			upper = difference;
		}
	}
}

// Haar wavelet inverse transform
// In place, for power-of-two lengths (asserted). Walks the scales from
// coarsest to finest and turns each stored sum/difference pair back into the
// two values half their sum and half their difference.
void haarTransform1(Vector &d)
{
	const int n = d.size();
	assert((n & (n-1)) == 0);

	// scales from coarsest to finest
	for (int half = n>>1; half > 0; half >>= 1)
	{
		const int step = half<<1;
		for (int base = 0; base < n; base += step)
		{
			// lifting steps
			real &lower = d[base];
			real &upper = d[base+half];
			const real sum = lower;
			const real difference = upper;
			lower = (sum+difference)/2;
			upper = (sum-difference)/2;
		}
	}
}

// Daubechies D4 wavelet forward transform
// In place, for power-of-two lengths (asserted), using four lifting passes per
// scale; neighbour indices wrap around modulo n.
// approx: when true, each detail slot data[j+m] is overwritten with the scaled
//         approximation data[j] instead of the scaled detail value.
void d4Wavelet0(Vector &data, bool approx)
{
	const int n = data.size();
	assert((n & (n-1)) == 0);

	// lifting constants
	const auto root3 = sqrt(3.F);
	const auto root2 = sqrt(2.F);
	const auto nearWeight = root3/4.F;
	const auto farWeight = (root3-2.F)/4.F;
	const auto approxGain = (root3-1.F)/root2;
	const int wrap = n-1;

	// at each scale, finest first; every pass runs over the whole vector
	// before the next one starts
	for (int m = 1; m < n; m *= 2)
	{
		const int step = m*2;

		// lifting steps
		for (int j = 0; j < n; j += step)
			data[j] += root3*data[j+m];
		for (int j = 0; j < n; j += step)
			data[j+m] -= nearWeight*data[j] + farWeight*data[(j-step) & wrap];
		for (int j = 0; j < n; j += step)
			data[j] -= data[(j+m+step) & wrap];
		for (int j = 0; j < n; j += step)
		{
			data[j] *= approxGain;
			if (approx)
				data[j+m] = data[j];
			else
				data[j+m] = data[j+m]*(root3+1.F)/root2;
		}
	}
}

// Daubechies D4 wavelet inverse transform
// In place, for power-of-two lengths (asserted). Runs the scales from coarsest
// to finest and, within each scale, applies the lifting passes in the reverse
// order with the opposite signs and the reciprocal gains.
void d4Wavelet1(Vector &data)
{
	const int n = data.size();
	assert((n & (n-1)) == 0);

	// lifting constants
	const auto root3 = sqrt(3.F);
	const auto root2 = sqrt(2.F);
	const auto detailGain = (root3-1.F)/root2;
	const auto approxGain = (root3+1.F)/root2;
	const auto nearWeight = root3/4.F;
	const auto farWeight = (root3-2.F)/4.F;
	const int wrap = n-1;

	// at each scale, coarsest first; every pass runs over the whole vector
	// before the next one starts
	for (int m = n/2; m >= 1; m /= 2)
	{
		const int step = m*2;

		// lifting steps
		for (int j = 0; j < n; j += step)
		{
			data[j+m] *= detailGain;
			data[j] *= approxGain;
		}
		for (int j = 0; j < n; j += step)
			data[j] += data[(j+m+step) & wrap];
		for (int j = 0; j < n; j += step)
			data[j+m] += nearWeight*data[j] + farWeight*data[(j-step) & wrap];
		for (int j = 0; j < n; j += step)
			data[j] -= root3*data[j+m];
	}
}

// Hann window
// tau is one full turn, twice the arc cosine of -1
constexpr real tau = ::acos(-1)*2;
// Multiplies element i of d in place by 0.5*(1 - cos(tau*i/n)), n = d.size().
void hann(Vector &d)
{
	const int n = d.size();

	// filter
	int i = 0;
	while (i < n)
	{
		const auto weight = 0.5*(1.-cos(tau*i/n));
		d[i] *= weight;
		++i;
	}
}
}
